{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 训练Siamese network\n",
    "# 1.构造训练数据集\n",
    "# 2.搭建Siamese 网络\n",
    "# 3.训练\n",
    "# 4.测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "m8L4KXFI7zx5"
   },
   "outputs": [],
   "source": [
    "# 导入相关包\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the installed TensorFlow version\n",
    "tf.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the class folders into training / test / query groups.\n",
    "# glob.glob() returns entries in arbitrary order, so each listing is sorted\n",
    "# before slicing; otherwise which classes count as 'the last five' could\n",
    "# change between runs and differ between the train and test folders.\n",
    "\n",
    "# Training classes: every class folder except the last five\n",
    "train_classes = sorted(glob.glob('dogImages/train/*'))[:-5]\n",
    "# Test classes: the test folder with its last five class folders dropped\n",
    "# NOTE(review): assumes train/ and test/ contain the same class folder names,\n",
    "# so that the same five classes are excluded from both -- confirm.\n",
    "test_classes = sorted(glob.glob('dogImages/test/*'))[:-5]\n",
    "\n",
    "# Support-set / query classes: the five class folders held out from training\n",
    "query_classes = sorted(glob.glob('dogImages/train/*'))[-5:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of test classes and number of held-out query classes\n",
    "len(test_classes),len(query_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the class folders held out as query classes\n",
    "query_classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def list_class_files(class_dirs, pattern):\n",
    "    \"\"\"Collect the files of every class directory.\n",
    "\n",
    "    Args:\n",
    "        class_dirs: iterable of class directory paths.\n",
    "        pattern: glob pattern appended to each directory path, e.g. '/*.jpg'.\n",
    "\n",
    "    Returns:\n",
    "        A list with one entry per class directory (same order as class_dirs);\n",
    "        each entry is the sorted list of paths matching the pattern.\n",
    "    \"\"\"\n",
    "    # Sorting makes the listing reproducible: glob.glob() gives no ordering guarantee\n",
    "    return [sorted(glob.glob(class_dir + pattern)) for class_dir in class_dirs]\n",
    "\n",
    "\n",
    "# Per-class image paths for the training and test classes (.jpg files only)\n",
    "train_files = list_class_files(train_classes, '/*.jpg')\n",
    "test_files = list_class_files(test_classes, '/*.jpg')\n",
    "\n",
    "# NOTE(review): the query folders are listed with '/*' (every entry, not only\n",
    "# *.jpg), exactly as before -- confirm whether '/*.jpg' was intended here too.\n",
    "query_files = list_class_files(query_classes, '/*')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# query_files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构造自定义数据集\n",
    "import random\n",
    "from tensorflow.keras.preprocessing import image\n",
    "from tensorflow.keras.applications.resnet import preprocess_input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Endless random triplet generator used for training and evaluation\n",
    "\n",
    "class SiameseSequence(tf.keras.utils.Sequence):\n",
    "    \"\"\"Keras Sequence serving random (anchor, positive, negative) image batches.\n",
    "\n",
    "    Batches are drawn at random on every request, so the index given to\n",
    "    ``__getitem__`` is ignored and ``batch_len`` alone decides how many\n",
    "    batches make up one epoch.\n",
    "\n",
    "    Args:\n",
    "        file_list: list of per-class lists of image paths (outer index = class).\n",
    "        batch_size: number of triplets in each batch.\n",
    "        batch_len: number of batches reported as one epoch.\n",
    "        target_size: (height, width) each image is resized to when loaded.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, file_list, batch_size, batch_len, target_size=(200, 200)):\n",
    "        # Initialise the Keras base class; recent Keras releases warn when a\n",
    "        # Sequence subclass skips this call.\n",
    "        super().__init__()\n",
    "        # File paths grouped by class\n",
    "        self.file_list = file_list\n",
    "        # Number of classes\n",
    "        self.num_classes = len(file_list)\n",
    "        # Triplets per batch\n",
    "        self.batch_size = batch_size\n",
    "        # Sampling is endless, so the number of batches per epoch is set by hand\n",
    "        self.batch_len = batch_len\n",
    "        # Size every image is resized to\n",
    "        self.target_size = target_size\n",
    "        # An anchor/positive class needs two distinct images, a negative class one\n",
    "        self._anchor_classes = [i for i, files in enumerate(file_list) if len(files) >= 2]\n",
    "        self._negative_classes = [i for i, files in enumerate(file_list) if len(files) >= 1]\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of batches that make up one epoch.\"\"\"\n",
    "        return self.batch_len\n",
    "\n",
    "    def __preprocess_image(self, filename):\n",
    "        \"\"\"Load one image file, resize it and apply the ResNet preprocessing.\"\"\"\n",
    "        img = image.load_img(filename, target_size=self.target_size)\n",
    "        img = image.img_to_array(img)\n",
    "        img = preprocess_input(img)\n",
    "        return img\n",
    "\n",
    "    def __getOneSample(self):\n",
    "        \"\"\"Draw one random triplet of file paths.\n",
    "\n",
    "        Returns:\n",
    "            (anchor_path, positive_path, negative_path): anchor and positive are\n",
    "            two different images of one class, negative is an image of another.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the file list cannot supply a triplet.\n",
    "        \"\"\"\n",
    "        if not self._anchor_classes or len(self._negative_classes) < 2:\n",
    "            raise ValueError(\n",
    "                'SiameseSequence needs at least one class with two or more images '\n",
    "                'and a second non-empty class to build a triplet'\n",
    "            )\n",
    "\n",
    "        # Only classes with at least two images can provide anchor + positive;\n",
    "        # drawing the anchor class from those avoids a failing random.sample().\n",
    "        anchor_class = random.choice(self._anchor_classes)\n",
    "        # The negative comes from any other non-empty class\n",
    "        negative_class = random.choice(\n",
    "            [c for c in self._negative_classes if c != anchor_class]\n",
    "        )\n",
    "\n",
    "        # Two distinct images of the anchor class serve as A and P\n",
    "        anchor_path, positive_path = random.sample(self.file_list[anchor_class], 2)\n",
    "        # One image of the other class serves as N\n",
    "        negative_path = random.choice(self.file_list[negative_class])\n",
    "\n",
    "        return anchor_path, positive_path, negative_path\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Build one batch of random triplets; ``idx`` is not used.\n",
    "\n",
    "        Returns:\n",
    "            Three numpy arrays (anchors, positives, negatives), each stacking\n",
    "            ``batch_size`` preprocessed images.\n",
    "        \"\"\"\n",
    "        batch_A = []\n",
    "        batch_P = []\n",
    "        batch_N = []\n",
    "\n",
    "        for _ in range(self.batch_size):\n",
    "            A, P, N = self.__getOneSample()\n",
    "\n",
    "            batch_A.append(self.__preprocess_image(A))\n",
    "            batch_P.append(self.__preprocess_image(P))\n",
    "            batch_N.append(self.__preprocess_image(N))\n",
    "\n",
    "        return np.array(batch_A), np.array(batch_P), np.array(batch_N)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the triplet generators: 32 triplets per batch everywhere,\n",
    "# 100 batches per epoch for training and 10 for the test and query sets\n",
    "traingen = SiameseSequence(train_files, batch_size=32, batch_len=100)\n",
    "test_gen = SiameseSequence(test_files, batch_size=32, batch_len=10)\n",
    "query_gen = SiameseSequence(query_files, batch_size=32, batch_len=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6Dcg5F-k7zyA"
   },
   "outputs": [],
   "source": [
    "def _to_displayable(img):\n",
    "    \"\"\"Rescale an image to [0, 1] and reverse its channel order for plotting.\n",
    "\n",
    "    Args:\n",
    "        img: array whose last axis holds the colour channels.\n",
    "\n",
    "    Returns:\n",
    "        float32 array of the same shape, min-max scaled to [0, 1], with the\n",
    "        last axis reversed.\n",
    "    \"\"\"\n",
    "    img = np.asarray(img, dtype=np.float32)\n",
    "    min_v = img.min()\n",
    "    value_range = img.max() - min_v\n",
    "    if value_range > 0:\n",
    "        # Divide by the range (not by the maximum) so the result spans exactly [0, 1]\n",
    "        img = (img - min_v) / value_range\n",
    "    else:\n",
    "        # Constant image: nothing to scale, avoid dividing by zero\n",
    "        img = np.zeros_like(img)\n",
    "    # Reverse the channel axis (BGR -> RGB), as the earlier cv2.COLOR_BGR2RGB call did.\n",
    "    # NOTE(review): presumably needed because the ResNet preprocess_input emits BGR -- confirm.\n",
    "    return img[..., ::-1]\n",
    "\n",
    "\n",
    "def visualize(anchor, positive, negative, num_rows=3):\n",
    "    \"\"\"Plot the first triplets of a batch, one (anchor, positive, negative) row each.\n",
    "\n",
    "    Args:\n",
    "        anchor: batch of anchor images.\n",
    "        positive: batch of positive images.\n",
    "        negative: batch of negative images.\n",
    "        num_rows: maximum number of triplets to draw; fewer are drawn when\n",
    "            the batches are smaller.\n",
    "    \"\"\"\n",
    "    # Never index past the end of a short batch\n",
    "    num_rows = min(num_rows, len(anchor), len(positive), len(negative))\n",
    "    plt.figure(figsize=(9, 9))\n",
    "\n",
    "    columns = (('anchor', anchor), ('positive', positive), ('negative', negative))\n",
    "    for row in range(num_rows):\n",
    "        for col, (title, batch) in enumerate(columns):\n",
    "            plt.subplot(num_rows, 3, 3 * row + col + 1)\n",
    "            plt.imshow(_to_displayable(batch[row]))\n",
    "            plt.title(title)\n",
    "            plt.axis('off')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview triplets from one query batch; SiameseSequence.__getitem__ does not\n",
    "# use its index argument, so this is simply a freshly drawn random batch\n",
    "visualize(*query_gen[5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构造获取特征向量的网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras import applications\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import losses\n",
    "from tensorflow.keras import optimizers\n",
    "from tensorflow.keras import metrics\n",
    "from tensorflow.keras import Model\n",
    "from tensorflow.keras import Sequential\n",
    "from tensorflow.keras.applications import ResNet50\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JePVRkqv7zyB"
   },
   "outputs": [],
   "source": [
    "# Backbone: ImageNet-pretrained ResNet50 without its classification head\n",
    "base_cnn = ResNet50(include_top=False, input_shape=(200,200,3), weights=\"imagenet\")\n",
    "\n",
    "# Projection head stacked on the backbone output; its last layer is Dense(256)\n",
    "custom_layers = Sequential()\n",
    "custom_layers.add(layers.Flatten())\n",
    "custom_layers.add(layers.Dense(512, activation=\"relu\"))\n",
    "custom_layers.add(layers.BatchNormalization())\n",
    "custom_layers.add(layers.Dense(256, activation=\"relu\"))\n",
    "custom_layers.add(layers.BatchNormalization())\n",
    "custom_layers.add(layers.Dense(256))\n",
    "\n",
    "output = custom_layers(base_cnn.output)\n",
    "\n",
    "# Bundle backbone + head into the embedding model\n",
    "embedding = Model(name=\"Embedding\", inputs=base_cnn.input, outputs=output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EnbDRFpaBGWO"
   },
   "outputs": [],
   "source": [
    "# 查看网络架构，查看可以训练的参数量，Keras默认都是可以训练的，除了BatchNormalization层\n",
    "# base_cnn.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "L7dM14iRBDxG"
   },
   "outputs": [],
   "source": [
    "# Transfer learning: every layer before conv5_block1_out stays frozen,\n",
    "# that layer and all layers after it are fine-tuned\n",
    "trainable = False\n",
    "for layer in base_cnn.layers:\n",
    "    # The flag flips to True at the marker layer and stays True afterwards\n",
    "    trainable = trainable or layer.name == \"conv5_block1_out\"\n",
    "    layer.trainable = trainable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 再次检查\n",
    "# base_cnn.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "x8jjLEYh7zyC"
   },
   "outputs": [],
   "source": [
    "# https://www.tensorflow.org/api_docs/python/tf/keras/Model#used-in-the-notebooks_1\n",
    "\n",
    "class SiameseModel(Model):\n",
    "    \"\"\"Siamese network trained with a triplet loss.\n",
    "\n",
    "    One shared embedding network encodes anchor, positive and negative\n",
    "    images. The per-triplet loss is\n",
    "    ``max(d(A, P) - d(A, N) + margin, 0)`` where ``d`` is the squared\n",
    "    Euclidean distance between embeddings.\n",
    "\n",
    "    Args:\n",
    "        embedding: model mapping a batch of images to embedding vectors.\n",
    "        margin: margin of the triplet loss.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, embedding, margin=0.5):\n",
    "        super(SiameseModel, self).__init__()\n",
    "\n",
    "        # Shared embedding network\n",
    "        self.embedding = embedding\n",
    "        # Triplet-loss margin\n",
    "        self.margin = margin\n",
    "        # Running mean of the loss, reported by fit() / evaluate()\n",
    "        self.loss_tracker = metrics.Mean(name=\"loss\")\n",
    "\n",
    "    def call(self, inputs, training=None):\n",
    "        \"\"\"Return the anchor-positive and anchor-negative squared distances.\n",
    "\n",
    "        Args:\n",
    "            inputs: sequence holding the anchor, positive and negative image\n",
    "                batches at indices 0, 1 and 2.\n",
    "            training: forwarded to the embedding network so that layers such\n",
    "                as BatchNormalization run in the intended mode.\n",
    "\n",
    "        Returns:\n",
    "            (ap_distance, an_distance), each reduced over the last axis.\n",
    "        \"\"\"\n",
    "        anchor_img = inputs[0]\n",
    "        positive_img = inputs[1]\n",
    "        negative_img = inputs[2]\n",
    "\n",
    "        # Embed the three image batches with the shared network\n",
    "        anchor_embedding = self.embedding(anchor_img, training=training)\n",
    "        positive_embedding = self.embedding(positive_img, training=training)\n",
    "        negative_embedding = self.embedding(negative_img, training=training)\n",
    "\n",
    "        # Squared Euclidean distances\n",
    "        ap_distance = tf.reduce_sum(tf.square(anchor_embedding - positive_embedding), -1)\n",
    "        an_distance = tf.reduce_sum(tf.square(anchor_embedding - negative_embedding), -1)\n",
    "\n",
    "        return ap_distance, an_distance\n",
    "\n",
    "    def train_step(self, data):\n",
    "        \"\"\"Run one optimisation step on a batch of triplets and report the mean loss.\"\"\"\n",
    "        # Record the forward pass so gradients can be computed\n",
    "        with tf.GradientTape() as tape:\n",
    "            # Training mode: the forward pass must be flagged as training\n",
    "            loss = self._compute_triplet_loss(data, training=True)\n",
    "        # NOTE(review): loss is not reduced to a scalar before differentiation;\n",
    "        # presumably the gradient is that of the summed per-triplet losses -- confirm.\n",
    "        gradients = tape.gradient(loss, self.trainable_weights)\n",
    "        # Apply the update\n",
    "        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))\n",
    "        # Track the loss\n",
    "        self.loss_tracker.update_state(loss)\n",
    "\n",
    "        return {\"loss\": self.loss_tracker.result()}\n",
    "\n",
    "    def test_step(self, data):\n",
    "        \"\"\"Evaluate one batch of triplets in inference mode and report the mean loss.\"\"\"\n",
    "        loss = self._compute_triplet_loss(data, training=False)\n",
    "\n",
    "        self.loss_tracker.update_state(loss)\n",
    "        return {\"loss\": self.loss_tracker.result()}\n",
    "\n",
    "    def _compute_triplet_loss(self, data, training=None):\n",
    "        \"\"\"Compute the triplet loss for a batch of (anchor, positive, negative) images.\n",
    "\n",
    "        Args:\n",
    "            data: the batch, in the layout expected by ``call``.\n",
    "            training: mode flag forwarded to the model call.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of losses, clipped below at zero.\n",
    "        \"\"\"\n",
    "        # Go through self(...) rather than self.call(...) so Keras handles the\n",
    "        # call and the training flag reaches the embedding network\n",
    "        ap_distance, an_distance = self(data, training=training)\n",
    "        loss = tf.maximum(ap_distance + self.margin - an_distance, 0.0)\n",
    "        return loss\n",
    "\n",
    "    @property\n",
    "    def metrics(self):\n",
    "        # Expose the tracker through the metrics property, as in the Keras\n",
    "        # custom-training guide, so Keras can manage its state\n",
    "        return [self.loss_tracker]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the Siamese model around the embedding network (margin 0.5)\n",
    "siamese_model = SiameseModel(embedding, margin=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: push one training batch through the model\n",
    "output = siamese_model(traingen[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the smoke-test result; per SiameseModel.call this is the pair of\n",
    "# (anchor-positive, anchor-negative) squared distances\n",
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9uqSzayk7zyD"
   },
   "outputs": [],
   "source": [
    "# Compile the model with an Adam optimizer (0.0001 passed as its first argument);\n",
    "# no loss is given here because SiameseModel computes the triplet loss in its\n",
    "# own train_step / test_step\n",
    "siamese_model.compile(optimizer=optimizers.Adam(0.0001))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train for 30 epochs, using the test generator as validation data\n",
    "history = siamese_model.fit(traingen,validation_data=test_gen,epochs=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the model -- note that what gets saved is the feature-extraction (embedding) model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the embedding network (not the SiameseModel wrapper) to latest_new.h5\n",
    "embedding.save('latest_new.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 绘制图"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the history recorded by fit() into a DataFrame; the plotting cell\n",
    "# reads its 'loss' and 'val_loss' columns\n",
    "history_pd = pd.DataFrame(history.history)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the training and validation loss curves on the current axes\n",
    "ax = plt.gca()\n",
    "for column in ('loss', 'val_loss'):\n",
    "    ax.plot(history_pd[column])\n",
    "ax.set_title('Model loss')\n",
    "ax.set_xlabel('epoch')\n",
    "ax.set_ylabel('loss')\n",
    "ax.legend(['train set','test set'], loc='upper right')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "孪生网络",
   "provenance": [],
   "toc_visible": true
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
